{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solving the Frozen Lake Problem with Policy Iteration\n",
    "\n",
    "In the frozen lake environment, our goal is to reach the goal state G from\n",
    "the starting state S without falling into any of the hole states H. Now, let's compute the optimal policy for this environment using the policy iteration method.\n",
    "\n",
    "First, let's import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's create the frozen lake environment using gym:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('FrozenLake-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "In policy iteration, we alternate between two steps: we compute the value function of the\n",
    "current policy, and then we extract a new, improved policy from that value function. Once the\n",
    "policy stops changing between iterations, it is the optimal policy and its value function is\n",
    "the optimal value function.\n",
    "\n",
    "So, first, let's learn how to compute the value function using the policy.\n",
    "\n",
    "\n",
    "## Computing the value function using the policy\n",
    "\n",
    "This step is similar to how we computed the value function in the value iteration\n",
    "method, with one difference. In value iteration, we compute the value of a state by taking\n",
    "the maximum over the Q values of all the actions. Here, we compute the value of a state using\n",
    "only the action that the given policy selects in that state. Now, let's define a function\n",
    "that computes the value function of a given policy.\n",
    "\n",
    "\n",
    "Let's define a function called `compute_value_function` which takes the policy as a\n",
    "parameter:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_value_function(policy):\n",
    "    \n",
    "    #now, let's define the number of iterations\n",
    "    num_iterations = 1000\n",
    "    \n",
    "    #define the threshold value\n",
    "    threshold = 1e-20\n",
    "    \n",
    "    #set the discount factor\n",
    "    gamma = 1.0\n",
    "    \n",
    "    #now, we will initialize the value table, with the value of all states to zero\n",
    "    value_table = np.zeros(env.observation_space.n)\n",
    "    \n",
    "    #for every iteration\n",
    "    for i in range(num_iterations):\n",
    "        \n",
    "        #keep a copy of the value table from the previous iteration; the values of the\n",
    "        #successor states are read from this copy during the update\n",
    "        updated_value_table = np.copy(value_table)\n",
    "        \n",
    "        #for each state, we select the action according to the given policy and then update the\n",
    "        #value of the state using the selected action, as shown below\n",
    "        \n",
    "        #for each state\n",
    "        for s in range(env.observation_space.n):\n",
    "            \n",
    "            #select the action in the state according to the policy; cast it to int because\n",
    "            #the policy array stores floats\n",
    "            a = int(policy[s])\n",
    "            \n",
    "            #compute the value of the state using the selected action\n",
    "            value_table[s] = sum([prob * (r + gamma * updated_value_table[s_]) \n",
    "                                        for prob, s_, r, _ in env.P[s][a]])\n",
    "            \n",
    "        #after updating the values of all the states, check whether the total absolute\n",
    "        #difference between the value tables of the current and previous iterations is less\n",
    "        #than or equal to the threshold. If it is, break the loop and return the value table\n",
    "        #as the value function of the given policy\n",
    "\n",
    "        if (np.sum((np.fabs(updated_value_table - value_table))) <= threshold):\n",
    "            break\n",
    "            \n",
    "    return value_table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now that we have computed the value function of the policy, let's see how to extract the\n",
    "policy from the value function. \n",
    "\n",
    "## Extracting policy from the value function\n",
    "\n",
    "This step is exactly the same as how we extracted the policy from the value function in the\n",
    "value iteration method. As before, we\n",
    "define a function called `extract_policy` that extracts a policy from a given value function,\n",
    "as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_policy(value_table):\n",
    "    \n",
    "    #set the discount factor\n",
    "    gamma = 1.0\n",
    "     \n",
    "    #first, we initialize the policy with zeros, that is, we set the action for every state to\n",
    "    #zero\n",
    "    policy = np.zeros(env.observation_space.n) \n",
    "    \n",
    "    #now, we compute the Q function using the given value function. In policy iteration, this\n",
    "    #is the value function of the current policy, which is not necessarily optimal. Selecting\n",
    "    #the action with the maximum Q value in each state gives a policy that is at least as good\n",
    "    #as the current one; it is the optimal policy once policy iteration has converged.\n",
    "    \n",
    "    #As shown below, for each state, we compute the Q values for all the actions in the state and\n",
    "    #then we extract policy by selecting the action which has maximum Q value.\n",
    "    \n",
    "    #for each state\n",
    "    for s in range(env.observation_space.n):\n",
    "        \n",
    "        #compute the Q value of all the actions in the state\n",
    "        Q_values = [sum([prob*(r + gamma * value_table[s_])\n",
    "                             for prob, s_, r, _ in env.P[s][a]]) \n",
    "                                   for a in range(env.action_space.n)] \n",
    "                \n",
    "        #extract policy by selecting the action which has maximum Q value\n",
    "        policy[s] = np.argmax(np.array(Q_values))        \n",
    "    \n",
    "    return policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Putting it all together\n",
    "\n",
    "First, let's define a function called `policy_iteration` which takes the environment as a\n",
    "parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy_iteration(env):\n",
    "    \n",
    "    #set the number of iterations\n",
    "    num_iterations = 1000\n",
    "    \n",
    "    #we learned that in the policy iteration method, we begin with an arbitrary initial policy.\n",
    "    #here, we initialize a policy which selects the action 0 in all the states\n",
    "    policy = np.zeros(env.observation_space.n)  \n",
    "    \n",
    "    #for every iteration\n",
    "    for i in range(num_iterations):\n",
    "        #compute the value function using the policy\n",
    "        value_function = compute_value_function(policy)\n",
    "        \n",
    "        #extract the new policy from the computed value function\n",
    "        new_policy = extract_policy(value_function)\n",
    "           \n",
    "        #if the policy and new_policy are same then break the loop\n",
    "        if (np.all(policy == new_policy)):\n",
    "            break\n",
    "        \n",
    "        #else, update the current policy to new_policy\n",
    "        policy = new_policy\n",
    "        \n",
    "    return policy\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now, let's perform policy iteration and find the optimal policy in the frozen\n",
    "lake environment.\n",
    "\n",
    "We just pass the frozen lake environment to our `policy_iteration`\n",
    "function, as shown below, and get the optimal policy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimal_policy = policy_iteration(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can print the optimal policy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 3. 3. 3. 0. 0. 0. 0. 3. 1. 0. 0. 0. 2. 1. 0.]\n"
     ]
    }
   ],
   "source": [
    "print(optimal_policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
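    "As we can observe, the optimal policy lists one action for each of the 16 states of the 4x4\n",
    "grid, numbered row by row from the top-left corner. The actions are encoded as 0 (left),\n",
    "1 (down), 2 (right), and 3 (up). For the hole states and the goal state, the listed action 0\n",
    "is arbitrary, since the episode ends in those states. Thus, we learned how to perform the\n",
    "policy iteration method to compute the optimal policy."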
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
